# fcn_ntk_init_posterior: brute-force grid search for the generalization bound of a fully connected network on MNIST.
import torch
from pbb.utils import runexp
import argparse
import numpy as np
# Command-line interface. Both options are floats that are forwarded to
# pbb.utils.runexp further down in this script, so the help text states
# where each value ends up instead of repeating the option name.
parse = argparse.ArgumentParser(
    description=(
        "Run the MNIST fully-connected-network experiment via "
        "pbb.utils.runexp and save the returned risk/error values "
        "to a .npy file."
    )
)
parse.add_argument(
    '--number_for_prior',
    type=float,
    default=0.2,
    help='value forwarded to runexp as perc_prior (default: 0.2)',
)
parse.add_argument(
    '--kl_divergent',
    type=float,
    default=1.0,
    help='value forwarded to runexp as kl_penalty (default: 1.0)',
)
# Parses sys.argv at import time; unknown options abort the script.
args = parse.parse_args()

# Pick the compute device: CUDA when torch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# Echo CUDA availability (True/False) to stdout.
print(torch.cuda.is_available())

# ---------------------------------------------------------------------------
# Experiment configuration. Every value except BATCH_SIZE is handed to
# runexp further down in this script.
# ---------------------------------------------------------------------------

# NOTE(review): BATCH_SIZE is defined but never forwarded to runexp in this
# script -- confirm whether that is intentional.
BATCH_SIZE = 250

# Epoch counts, forwarded as train_epochs / prior_epochs.
TRAIN_EPOCHS = 10
prior_epochs = 40

# Forwarded as delta / delta_test.
DELTA = 0.025
DELTA_TEST = 0.01

# Prior settings (passed positionally to runexp).
PRIOR = 'learnt'
SIGMAPRIOR = 0.1
PMIN = 1e-5

# Optimiser settings (passed positionally to runexp).
LEARNING_RATE = 1
MOMENTUM = 0.95
LEARNING_RATE_PRIOR = 1.9
MOMENTUM_PRIOR = 0.99

# The paper uses 150,000 Monte Carlo samples, which usually takes several
# hours to compute.
MC_SAMPLES = 150_000

# Values taken from the command line.
KL_PENALTY = args.kl_divergent
perc_prior = args.number_for_prior

# note all of these running examples have different settings!
# Launch the experiment. The first ten arguments are positional, so their
# order must stay in sync with runexp's signature.
risk_01, ens_err, post_err, stch_err = runexp(
    'mnist',
    'fquad',
    PRIOR,
    'fcn',
    SIGMAPRIOR,
    PMIN,
    LEARNING_RATE,
    MOMENTUM,
    LEARNING_RATE_PRIOR,
    MOMENTUM_PRIOR,
    delta=DELTA,
    delta_test=DELTA_TEST,
    mc_samples=MC_SAMPLES,
    train_epochs=TRAIN_EPOCHS,
    device=DEVICE,
    perc_train=1.0,
    verbose=True,
    perc_prior=perc_prior,
    prior_epochs=prior_epochs,
    kl_penalty=KL_PENALTY,
    dropout_prob=0.2,
)

# Store the four returned values in the working directory; the file name
# encodes the KL penalty, the prior sigma and perc_prior of this run.
np.save(
    f"{KL_PENALTY}_{SIGMAPRIOR}_{perc_prior}_fcn_.npy",
    [risk_01, ens_err, post_err, stch_err],
)

